import time
from datetime import datetime, timedelta
import json
import os
import traceback
import pymysql
from numpy import record
from tqsdk import TqApi, TqAuth, TqKq
import pandas as pd

import tushare as ts
import logging
from logging.handlers import SMTPHandler
import string
pd.set_option('expand_frame_repr', False)  # Do not wrap wide DataFrames onto several lines when printed/logged.
pd.set_option('display.max_rows', 5000)  # Print at most 5000 rows of a DataFrame.


def ts2KQ(symbol: str) -> str:
    """Convert a tushare futures code (``CODE.SUFFIX``) into a TqSdk/KQ symbol.

    The exchange suffix (the part after the first dot) is matched by
    substring, in this order:

    * contains ``SHF`` -> ``SHFE.<code lower-cased>``
    * contains ``CFX`` -> ``CFFEX.<code unchanged>``
    * contains ``DCE`` -> ``<suffix>.<code lower-cased>``
    * contains ``ZCE`` -> ``CZCE.<code with its first digit removed>``
      (e.g. ``AP2107.ZCE`` -> ``CZCE.AP107``)
    * contains ``INE`` -> ``<suffix>.<code lower-cased>``

    Returns ``None`` when no suffix matches. Raises ``IndexError`` when
    *symbol* has no dot.
    """
    parts = symbol.split('.')
    code, suffix = parts[0], parts[1]

    if 'SHF' in suffix:
        return f'SHFE.{code.lower()}'
    if 'CFX' in suffix:
        return f'CFFEX.{code}'
    if 'DCE' in suffix:
        return f'{suffix}.{code.lower()}'
    if 'ZCE' in suffix:
        # CZCE uses a 3-digit year+month, so drop the first digit of the code.
        first_digit = next(
            (pos for pos, ch in enumerate(code) if ch.isdigit()), None)
        if first_digit is not None:
            code = code[:first_digit] + code[first_digit + 1:]
        return f'CZCE.{code}'
    if 'INE' in suffix:
        return f'{suffix}.{code.lower()}'
    return None


def init_logger(cfg):
    """Build the ``'Yhlz'`` logger with a file handler and an e-mail handler.

    Args:
        cfg: parsed config dict. Reads these keys under ``cfg['log']``:
            ``format``, ``file``, ``file_level``, ``mail_host``,
            ``mail_host_port``, ``fromaddr``, ``toaddrs``, ``subject``,
            ``credential_name``, ``credentials_key``, ``mail_level``, ``level``.

    Returns:
        The configured ``logging.Logger``. It does not propagate records to
        the root logger.
    """
    logger = logging.getLogger('Yhlz')
    fm = logging.Formatter(cfg['log']['format'])

    # Log messages are mostly Chinese: pin UTF-8 so the log file does not
    # depend on (and cannot fail under) the platform's locale encoding.
    fh = logging.FileHandler(cfg['log']['file'], encoding='utf-8')
    fh.setFormatter(fm)
    fh.setLevel(cfg['log']['file_level'])

    # Records at or above mail_level are also sent by SMTP.
    mail_handler = SMTPHandler(
        mailhost=(cfg['log']['mail_host'], cfg['log']['mail_host_port']),
        fromaddr=cfg['log']['fromaddr'],
        toaddrs=cfg['log']['toaddrs'],
        subject=cfg['log']['subject'],
        credentials=(cfg['log']['credential_name'],
                     cfg['log']['credentials_key']))
    mail_handler.setLevel(cfg['log']['mail_level'])
    mail_handler.setFormatter(fm)

    logger.addHandler(mail_handler)
    logger.setLevel(cfg['log']['level'])
    logger.addHandler(fh)
    logger.propagate = False
    return logger


def _load_contract_info(cfg):
    """Return tushare's still-listed futures contracts on the five exchanges, minus known bad rows."""
    tsObj = ts.pro_api(cfg['turtle']['tushare_key'])
    frames = [tsObj.fut_basic(exchange=exchange)
              for exchange in ('SHFE', 'DCE', 'CZCE', 'CFFEX', 'INE')]
    # ignore_index: each fut_basic frame is indexed from 0, so a plain concat
    # yields duplicate labels and any label-based drop would hit every exchange.
    contractInfo = pd.concat(frames, ignore_index=True)
    # Keep only contracts that have not been delisted yet.
    contractInfo = contractInfo[contractInfo['delist_date']
                                > datetime.now().strftime('%Y%m%d')]
    # tushare's crude-oil list contains odd "sctas" (TAS) contracts; drop them.
    contractInfo = contractInfo.loc[~contractInfo['symbol'].str.contains('TAS')]
    # This bitumen contract was never actually listed (bad tushare data).
    contractInfo = contractInfo.loc[contractInfo['symbol'] != 'BU2302']
    return contractInfo


def _subscribe_klines(api, contractInfo, logger):
    """Subscribe 60 daily bars per contract and return them keyed by KQ symbol."""
    klines = {}
    for ts_code in contractInfo['ts_code']:
        logger.debug(f'{ts_code}')
        symbol = ts2KQ(ts_code)
        logger.debug(f'订阅k线：{symbol}')
        try:
            klines[symbol] = api.get_kline_serial(symbol, 86400, 60)
            # A first-row datetime of 0 is taken to mean fewer than 60 bars of
            # history (60 bars are requested, so this is not a 55-bar filter).
            if klines[symbol].loc[0, 'datetime'] == 0:
                klines.pop(symbol)
                logger.info(
                    f'{symbol} 这个合约上市不满60天，不进行判断！')
        except Exception as e:
            logger.info(f'{symbol} 这个合约订阅k线出现问题！{e}')
    return klines


def ini_trade(logger, cfg):
    """Log in to TqSdk, pick liquid still-listed contracts and connect to MySQL.

    Args:
        logger: logger used for progress and error messages.
        cfg: parsed config dict; reads ``cfg['turtle']`` (credentials, tushare
            key, ``min_open_intrest``, ``min_volume``) and ``cfg['db']``.

    Returns:
        ``(api, cur, conn, klines)``: the TqApi, a pymysql cursor, the pymysql
        connection, and a dict mapping KQ symbol -> daily K-line serial.

    Raises:
        SystemExit: exit code 1 when the TqSdk login fails.
        Exception: any later failure (tushare, DB connect, subscription) is
            re-raised after the DB connection (if opened) and the api are
            closed, because the caller never receives them to clean up.
    """
    try:
        api = TqApi(auth=TqAuth(
            cfg['turtle']['usr_name'], cfg['turtle']['passwd']))
    except Exception:
        logger.exception('problem with sign in!')
        raise SystemExit(1)

    conn = None
    try:
        contractInfo = _load_contract_info(cfg)

        try:
            conn = pymysql.connect(host=cfg['db']['host'],
                                   port=cfg['db']['port'],
                                   user=cfg['db']['user'],
                                   password=cfg['db']['passwd'],
                                   db=cfg['db']['database'])
        except Exception as e:
            logger.error(f'连接数据库出现问题：{e}')
            raise  # logged; re-raised (resources are released below)
        cur = conn.cursor()

        klines = _subscribe_klines(api, contractInfo, logger)

        # Drop illiquid contracts, judged on the bar before the latest one
        # (iloc[-2]; presumably the last completed session).
        illiquid = [
            symbol for symbol, kline in klines.items()
            if kline['close_oi'].iloc[-2] < cfg['turtle']['min_open_intrest']
            or kline['volume'].iloc[-2] < cfg['turtle']['min_volume']
        ]
        for symbol in illiquid:
            klines.pop(symbol)
    except BaseException:
        if conn is not None:
            conn.close()
        api.close()
        raise

    return api, cur, conn, klines


def _dump_json_atomically(path, data):
    """Write *data* as JSON to *path* via a temp file and ``os.replace``, so *path* is never left half-written."""
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(data, f)
    os.replace(tmp_path, path)


def stop(now):
    """Persist the new-high / new-low symbol lists to ``./temp.json`` and log the file.

    ``./temp.json`` serves as a tiny database between runs. At exactly 15:15
    both lists are reset to empty. At any other time the module-level
    ``new_high`` / ``new_low`` lists are written into the file, keeping any
    other keys already stored there (a missing file starts from empty lists).

    The file is replaced atomically: a crash mid-write used to leave a
    truncated file that made the next start-up fail in ``json.load``.

    Args:
        now: a ``time.struct_time``; only ``tm_hour`` and ``tm_min`` are read.
    """
    path = './temp.json'
    if now.tm_hour == 15 and now.tm_min == 15:
        record_pos = {"new_high": [], "new_low": []}
    else:
        if os.path.exists(path):
            with open(path, 'r') as f:
                record_pos = json.load(f)
        else:
            record_pos = {"new_high": [], "new_low": []}
        # new_high / new_low are module-level lists maintained by the main loop.
        record_pos['new_high'] = new_high
        record_pos['new_low'] = new_low
    _dump_json_atomically(path, record_pos)

    with open(path) as f:
        logger.info('temp json内容为：' + f.read())


def check_new(symbol, kline, new_high, new_low):
    """Detect a fresh breakout of *symbol* above/below its recent range.

    The latest bar's high is compared with the maximum high of
    ``kline['high'].iloc[-55:-1]``, and the latest low with the minimum low of
    ``kline['low'].iloc[-55:-1]``. A symbol is reported only the first time it
    breaks out in a given direction: it is appended to *new_high* / *new_low*
    (mutated in place) and the lists are persisted via ``stop(now)``.

    Returns:
        ``'high'`` or ``'low'`` for a newly recorded breakout (high is checked
        first), otherwise ``None``.
    """
    logger.debug(f'{symbol}')

    latest_high = kline['high'].iloc[-1]
    prior_high = kline['high'].iloc[-55: -1].max()
    logger.debug(f"{latest_high} :: {prior_high}")
    if latest_high > prior_high and symbol not in new_high:
        new_high.append(symbol)
        logger.info(f'创新高的是：{new_high}')
        stop(now)
        return 'high'

    latest_low = kline['low'].iloc[-1]
    prior_low = kline['low'].iloc[-55: -1].min()
    logger.debug(f"{latest_low} :: {prior_low}")
    if latest_low < prior_low and symbol not in new_low:
        new_low.append(symbol)
        logger.info(f'创新低的是：{new_low}')
        stop(now)
        return 'low'

    return None


def check_pos_init():
    """Create the stateful ``check_pos`` closure that turns positions into signals.

    Loads ``./config/general_ticker_info.csv`` once. The third dot-separated
    field of its ``contract_name`` column, lower-cased, becomes a new
    ``symbol`` column. That column is used to look up each position's
    ``type`` / ``sub_type`` groups.

    Returns:
        ``check_pos(kline, pos)``; see its docstring.
    """
    add_sended = []
    cut_sended = []
    stop_sended = []
    tick_info = pd.read_csv('./config/general_ticker_info.csv')
    # Drop the exchange / KQ prefix so the value compares with a DB product code.
    tick_info['symbol'] = tick_info['contract_name'].apply(lambda x: x.split('.')[2])
    tick_info['symbol'] = tick_info['symbol'].str.lower()
    columns = ['long_short', 'symbol', 'atr', 'N', 'pos_count', 'open_price_1', 'open_price_2', 'open_price_3',
               'open_price_4', 'stopped']
    pos_ = pd.DataFrame()  # last position snapshot, with type/sub_type filled in

    def check_pos(kline, pos):
        """Check one contract's K-line against the current position table.

        Args:
            kline: K-line DataFrame of one contract; reads ``symbol`` (row 0),
                ``high``, ``low`` and ``close``.
            pos: rows fetched from the positions table, ordered as ``columns``.

        Returns:
            ``'stop'``, ``'add'`` or ``'cut'`` when a signal fires for a held
            symbol not yet signalled since the position table last changed;
            otherwise ``None``.
        """
        nonlocal add_sended, cut_sended, pos_, stop_sended
        pos = pd.DataFrame(pos, columns=columns)
        pos.set_index('symbol', inplace=True)
        if not pos_.empty:
            # Carry over the group labels so equals() only reacts to DB changes.
            pos[['type', 'sub_type']] = pos_[['type', 'sub_type']]

        if not pos.equals(pos_) or pos_.empty:
            # Positions changed: assume the operator has handled every mail so
            # far, so every signal may fire again.
            stop_sended = []
            add_sended = []
            cut_sended = []
            # Ensure the group columns exist even if no lookup below succeeds.
            for col in ('type', 'sub_type'):
                if col not in pos.columns:
                    pos[col] = None
            for i in pos.index:
                temp_df = tick_info.loc[tick_info['symbol'] == i.strip(string.digits), ['type', 'sub_type']]
                if temp_df.empty:
                    # Unknown product: report it and leave its groups unset
                    # instead of crashing on .iloc[0] of an empty frame.
                    logger.error(f'查询内容为空未找到合约{i},\n{tick_info}')
                    continue
                # DB symbols are lower-case only, hence the lower-cased lookup key.
                pos.loc[i, 'type'] = temp_df['type'].iloc[0]
                pos.loc[i, 'sub_type'] = temp_df['sub_type'].iloc[0]
            pos_ = pos.copy()
            logger.debug(f'pos is {pos}')
            logger.debug(f'pos_ is {pos_}')

        symbol = kline.loc[0, 'symbol'].split('.')[1].lower()
        logger.debug(f'{symbol}')
        if symbol not in pos.index:
            return None

        logger.debug(f'{kline}')
        logger.debug(f'{symbol} 有持仓')
        if (symbol in add_sended) or (symbol in cut_sended) or (symbol in stop_sended):
            return None  # already signalled since the positions last changed
        row = pos.loc[symbol, :]
        count = row['pos_count']
        logger.debug(f'row {row}')
        logger.debug(f'count {count}')

        # Unit limits: adding is allowed only while none of these caps is hit.
        same_side = pos['long_short'] == row['long_short']
        open_mark = True
        if count >= 4:  # at most 4 units in a single contract
            open_mark = False
        if row['stopped']:  # adding has been switched off for this position
            open_mark = False
        if pos['pos_count'].sum() >= 24:  # at most 24 units in total
            open_mark = False
        if pos.loc[same_side, 'pos_count'].sum() >= 12:  # at most 12 units in one direction
            open_mark = False
        if pos.loc[(pos['type'] == row['type']) & same_side, 'pos_count'].sum() >= 10:  # same type (loosely correlated): at most 10
            open_mark = False
        if pos.loc[(pos['sub_type'] == row['sub_type']) & same_side, 'pos_count'].sum() >= 6:  # same sub_type (presumably closely correlated): at most 6
            open_mark = False
        logger.debug(f'open_mark is {open_mark}')

        # Exit: longs exit below the lows of iloc[-20:-1], shorts above its highs.
        if row['long_short'] == 'long':
            if kline['low'].iloc[-1] < kline['low'].iloc[-20: -1].min():
                stop_sended.append(symbol)
                return 'stop'
        if row['long_short'] == 'short':
            if kline['high'].iloc[-1] > kline['high'].iloc[-20: -1].max():
                stop_sended.append(symbol)
                return 'stop'

        # pos_count must point at one of open_price_1..open_price_4; anything
        # else (0, >4, NULL/NaN) used to raise KeyError and stop the monitor.
        if pd.isna(count) or not 1 <= count <= 4:
            logger.error(f'row: {row} count 记录冲突！')
            return None
        count = int(count)
        price = row[f'open_price_{count}']
        if not price:
            logger.error(f'row: {row} count 记录冲突！')
            return None

        # If open_price_{count+1} is still recorded, that unit was filled and
        # later cut: re-add once the close crosses back over that level.
        next_price = row[f'open_price_{count + 1}'] if count < 4 else None
        close = kline['close'].iloc[-1]
        half_atr = row['atr'] / 2

        if row['long_short'] == 'long':
            logger.debug('long')
            if not row['stopped'] and open_mark:
                if pd.notna(next_price):
                    if next_price < close:
                        add_sended.append(symbol)
                        return 'add'
                elif price + half_atr < close:
                    add_sended.append(symbol)
                    return 'add'
            if price - half_atr > close:
                cut_sended.append(symbol)
                return 'cut'
        elif row['long_short'] == 'short':
            logger.debug('short')
            if not row['stopped'] and open_mark:
                if pd.notna(next_price):
                    if next_price > close:
                        add_sended.append(symbol)
                        return 'add'
                elif price - half_atr > close:
                    add_sended.append(symbol)
                    return 'add'
            if price + half_atr < close:
                cut_sended.append(symbol)
                return 'cut'
        else:
            logger.warning(f'{symbol} long_short未匹配请检查是否输入错误')

        return None

    return check_pos


if __name__ == '__main__':
    # Symbols that already made a new high / low today (persisted in temp.json).
    new_high = []
    new_low = []
    # Only write temp.json back once the saved state has really been loaded:
    # an early failure (bad config, logger set-up) used to overwrite the saved
    # lists with the empty defaults above and then crash on the missing logger.
    state_loaded = False
    try:
        with open('./config/turtle_config.json') as f:
            cfg = json.load(f)
        new_high = cfg['turtle']['new_high']
        new_low = cfg['turtle']['new_low']

        logger = init_logger(cfg)
        # temp.json (written by stop()) takes precedence over the config lists.
        if os.path.exists('./temp.json'):
            with open('./temp.json', 'r') as f:
                record_pos = json.load(f)
                new_high = record_pos['new_high']
                new_low = record_pos['new_low']
                logger.info(f'temp json 是：{record_pos}')
        state_loaded = True

        api, cur, conn, klines = ini_trade(logger, cfg)
        check_pos = check_pos_init()

        runTick = 0
        logger.critical('初始化完成，开始正常运行！')
        # Subscriptions occasionally misbehave and failing hard is not enough,
        # so a mail is sent once initialisation has finished each day.
        while True:
            api.wait_update(time.time() + 20)
            now = time.localtime()
            if now.tm_hour == 20 or now.tm_hour == 8:
                # Before the session opens the bars are still the previous
                # close, so judging them is meaningless.
                continue
            if runTick < time.time():
                runTick = time.time() + 5  # re-check at most every 5 seconds

                sql = f"SELECT * FROM {cfg['db']['table']}"
                cur.execute(sql)
                ret = cur.fetchall()
                conn.commit()  # end the transaction, otherwise updates made from other clients are not visible

                temp_high = []
                temp_low = []
                temp_add = []
                temp_cut = []
                for symbol, kline in klines.items():
                    # Only re-evaluate when the latest bar has changed.
                    if api.is_changing(kline.iloc[-1]):
                        res = check_new(symbol, kline, new_high, new_low)
                        if res == 'high':
                            temp_high.append(symbol)
                        elif res == 'low':
                            temp_low.append(symbol)
                        res = ''

                        if ret:
                            logger.debug(f'{symbol}')
                            res = check_pos(kline, ret)
                            logger.debug(f'{res}')
                            if res == 'add':
                                temp_add.append(symbol)
                            if res == 'cut':
                                temp_cut.append(symbol)
                            if res == 'stop':
                                logger.critical(f'{symbol} 需要平仓！')
                        else:
                            logger.warning('ret 为空值')
                # Short pauses between the critical (mailed) messages.
                if temp_high:
                    logger.critical(f'{temp_high} 出现新高！')
                    time.sleep(0.5)
                if temp_low:
                    logger.critical(f'{temp_low} 出现新低！')
                    time.sleep(0.5)
                if temp_add:
                    logger.critical(f'{temp_add} 需要加仓！')
                    time.sleep(0.5)
                if temp_cut:
                    logger.critical(f'{temp_cut} 需要减仓！')
                    time.sleep(0.5)

            if now.tm_hour == 3 or (now.tm_hour == 15 and now.tm_min == 15):
                logger.critical('停止运行！')
                break

    except Exception as e:
        if 'logger' in dir():
            logger.error('运行出现问题请立即检查！\n' + traceback.format_exc())
        else:
            print('运行出现问题请立即检查！\n' + traceback.format_exc())

    finally:
        if 'conn' in dir():
            conn.close()
        if 'api' in dir():
            api.close()
        if state_loaded:
            stop(time.localtime())
